/**
 * @file model.c
 *
 * Copyright (c) Huawei Technologies Co., Ltd. 2023. All Rights reserved.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 */
#include "acl/acl_mdl.h"
#include "acl/acl_base.h"
#include "framework/executor_c/ge_executor.h"
#include "model_desc_internal.h"
#include "log_inner.h"
#include "runtime/rt.h"
#include "model_config.h"

/*
 * Fill data->part from a load-config handle.
 *
 * Every partition pointer is copied from handle->exeOMInfo and every
 * partition size from handle->exeOMDesc. Neither argument is checked for
 * NULL here.
 */
static void SetPartFromHandle(const aclmdlConfigHandle *handle, ModelData *data) {
  /* Partition addresses, taken from exeOMInfo. */
  data->part.weightPtr = handle->exeOMInfo.weightPtr;
  data->part.modelDescPtr = handle->exeOMInfo.modelDescPtr;
  data->part.kernelPtr = handle->exeOMInfo.kernelPtr;
  data->part.paramPtr = handle->exeOMInfo.kernelArgsPtr;
  data->part.taskPtr = handle->exeOMInfo.staticTaskPtr;
  data->part.dynTaskPtr = handle->exeOMInfo.dynamicTaskPtr;

  /* Matching partition sizes, taken from exeOMDesc. */
  data->part.weightSize = handle->exeOMDesc.weightSize;
  data->part.modelDescSize = handle->exeOMDesc.modelDescSize;
  data->part.kernelSize = handle->exeOMDesc.kernelSize;
  data->part.paramSize = handle->exeOMDesc.kernelArgsSize;
  data->part.taskSize = handle->exeOMDesc.staticTaskSize;
  data->part.dynTaskSize = handle->exeOMDesc.dynamicTaskSize;
}

/*
 * Load a model from the file named by handle->loadPath.
 *
 * A zero-initialized ModelData is filled with the handle's priority, memType
 * and partition info, populated by LoadDataFromFile(), and then handed to
 * GeLoadModelFromData(). FreeModelData() is called once GeLoadModelFromData()
 * has returned, whether it succeeded or failed; it is not called when
 * LoadDataFromFile() itself fails.
 *
 * On success *modelId receives the new id and ACL_SUCCESS is returned;
 * otherwise the failing status is returned and *modelId is left untouched.
 */
static aclError ModelLoadFromFileWithMem(const aclmdlConfigHandle *handle, uint32_t *const modelId) {
  char *path = handle->loadPath;

  ModelData mdlData = {0};
  mdlData.priority = handle->priority;
  mdlData.memType = handle->memType;
  SetPartFromHandle(handle, &mdlData);

  Status status = LoadDataFromFile(path, &mdlData);
  if (status != SUCCESS) {
    ACL_LOG_CALL_ERROR("load data from file[%s] failed, ret[%u]", path, status);
    return status;
  }

  uint32_t loadedId = 0U;
  status = GeLoadModelFromData(&loadedId, &mdlData);
  if (status == SUCCESS) {
    *modelId = loadedId;
    FreeModelData(&mdlData);
    return ACL_SUCCESS;
  }

  FreeModelData(&mdlData);
  ACL_LOG_CALL_ERROR("load model from data failed, ret[%u]", status);
  return status;
}

/*
 * Load a model from the caller-provided memory buffer described by
 * handle->mdlAddr / handle->mdlSize.
 *
 * Returns ACL_ERROR_INVALID_PARAM when the buffer is NULL or its size is 0.
 * Otherwise the buffer is wrapped in a ModelData (flag NO_NEED_READ_FROM_FD,
 * fd NULL) and passed to GeLoadModelFromData(). On success *modelId receives
 * the new id and ACL_SUCCESS is returned; on failure the loader status is
 * returned and *modelId is left untouched.
 */
static aclError ModelLoadFromMemWithMem(const aclmdlConfigHandle *handle, uint32_t *const modelId) {
  size_t modelSize = handle->mdlSize;
  void *model = handle->mdlAddr;
  if ((modelSize == 0U) || (model == NULL)) {
    ACL_LOG_INNER_ERROR("param[%zu] is invalid.", modelSize);
    return ACL_ERROR_INVALID_PARAM;
  }

  /*
   * Zero-initialize like the file-based loader does, so that any ModelData
   * member not explicitly assigned below has a defined value instead of
   * indeterminate stack contents when the loader reads it.
   */
  ModelData data = {0};
  data.modelData = model;
  data.modelLen = (uint64_t)modelSize;
  data.priority = handle->priority;
  data.memType = handle->memType;
  data.fd = NULL;
  data.flag = NO_NEED_READ_FROM_FD;
  SetPartFromHandle(handle, &data);

  uint32_t id = 0U;
  Status ret = GeLoadModelFromData(&id, &data);
  if (ret != SUCCESS) {
    ACL_LOG_CALL_ERROR("load model from data failed, ret[%u].", ret);
    return ret;
  }

  *modelId = id;
  return ACL_SUCCESS;
}

/*
 * Build the execution descriptor for one model run.
 *
 * The workspace comes from handle->workPtr / handle->workSize when a handle
 * with a non-NULL workPtr is supplied. Failing that, and if stream is
 * non-NULL, rtStreamGetWorkspace() is asked for one (its return value is
 * deliberately ignored; the NULL check below covers failure). If no workspace
 * pointer is available, ACL_ERROR_INVALID_PARAM is returned and execDesc is
 * left unmodified.
 *
 * On success execDesc receives the stream and workspace, plus the mpamId,
 * aicQos, aicOst and mecTimeThreshHold values when a handle was supplied.
 */
static aclError SetExecHandle(const aclmdlExecConfigHandle *handle, aclrtStream stream, ExecHandleDesc *execDesc) {
  const bool hasHandle = (handle != NULL);
  void *workspace = NULL;
  size_t workspaceSize = 0U;

  if (hasHandle && (handle->workPtr != NULL)) {
    workspace = handle->workPtr;
    workspaceSize = handle->workSize;
  }

  if ((workspace == NULL) && (stream != NULL)) {
    (void)rtStreamGetWorkspace(stream, &workspace, &workspaceSize);
  }

  if (workspace == NULL) {
    ACL_LOG_ERROR("not set workSpacePtr.");
    return ACL_ERROR_INVALID_PARAM;
  }

  execDesc->stream = stream;
  execDesc->workPtr = workspace;
  execDesc->workSize = workspaceSize;
  if (!hasHandle) {
    return ACL_SUCCESS;
  }

  execDesc->mpamId = handle->mpamId;
  execDesc->aicQos = handle->aicQos;
  execDesc->aicOst = handle->aicOst;
  execDesc->mecTimeThreshHold = handle->mecTimeThreshHold;
  return ACL_SUCCESS;
}

/*
 * Shared implementation behind the synchronous and asynchronous execute
 * entry points.
 *
 * input and output are validated through
 * ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT, the execution descriptor is
 * prepared with SetExecHandle(), and ExecModel() is invoked with the
 * datasets cast to InputData / OutputData. Returns ACL_SUCCESS on success,
 * otherwise the error from SetExecHandle() or the status from ExecModel().
 */
static aclError ModelExecute(const uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output,
                             const bool sync, aclrtStream stream, const aclmdlExecConfigHandle *handle) {
  ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(input);
  ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(output);

  ExecHandleDesc desc = {0};
  const aclError prepRet = SetExecHandle(handle, stream, &desc);
  if (prepRet != ACL_SUCCESS) {
    return prepRet;
  }

  const Status execRet = ExecModel(modelId, &desc, sync, (const InputData *)input, (OutputData *)output);
  if (execRet == SUCCESS) {
    ACL_LOG_INFO("success execute, modelId[%u] mode[%d]", modelId, sync);
    return ACL_SUCCESS;
  }

  /* The mode label is only needed for the failure message. */
  ACL_LOG_CALL_ERROR("%s execute model failed, ret[%u], modelId[%u]", sync ? "sync" : "async", execRet, modelId);
  return execRet;
}

/*
 * Load a model as described by a config handle.
 *
 * handle and modelId must be non-NULL and the handle must pass
 * CheckMdlConfigHandle(); otherwise ACL_ERROR_INVALID_PARAM is returned.
 * handle->mdlLoadType selects the loader: the two FROM_FILE types go through
 * ModelLoadFromFileWithMem(), the two FROM_MEM types through
 * ModelLoadFromMemWithMem(). Any other load type yields
 * ACL_ERROR_INVALID_PARAM. A loader failure is logged and its error code is
 * returned unchanged.
 */
aclError aclmdlLoadWithConfig(const aclmdlConfigHandle *handle, uint32_t *modelId) {
  const bool argsPresent = (handle != NULL) && (modelId != NULL);
  if (!argsPresent) {
    ACL_LOG_ERROR("param must not be null.");
    return ACL_ERROR_INVALID_PARAM;
  }
  if (!CheckMdlConfigHandle(handle)) {
    return ACL_ERROR_INVALID_PARAM;
  }

  aclError loadRet;
  switch (handle->mdlLoadType) {
    case ACL_MDL_LOAD_FROM_FILE:
    case ACL_MDL_LOAD_FROM_FILE_WITH_MEM:
      loadRet = ModelLoadFromFileWithMem(handle, modelId);
      break;
    case ACL_MDL_LOAD_FROM_MEM:
    case ACL_MDL_LOAD_FROM_MEM_WITH_MEM:
      loadRet = ModelLoadFromMemWithMem(handle, modelId);
      break;
    default:
      ACL_LOG_INNER_ERROR("load type[%zu] is invalid, it should be in [%d, %d]",
          handle->mdlLoadType, ACL_MDL_LOAD_FROM_FILE, ACL_MDL_LOAD_FROM_MEM_WITH_MEM);
      return ACL_ERROR_INVALID_PARAM;
  }

  if (loadRet == ACL_SUCCESS) {
    return ACL_SUCCESS;
  }
  ACL_LOG_INNER_ERROR("model load failed, type[%zu].", handle->mdlLoadType);
  return loadRet;
}

/*
 * Execute a loaded model in synchronous mode: forwards every argument to
 * ModelExecute() with the sync flag set to true and returns its result.
 */
aclError aclmdlExecuteV2(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output, aclrtStream stream,
                         const aclmdlExecConfigHandle *handle) {
  const bool runSync = true;
  return ModelExecute(modelId, input, output, runSync, stream, handle);
}

/*
 * Execute a loaded model in asynchronous mode: forwards every argument to
 * ModelExecute() with the sync flag set to false and returns its result.
 */
aclError aclmdlExecuteAsyncV2(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output, aclrtStream stream,
                              const aclmdlExecConfigHandle *handle) {
  const bool runSync = false;
  return ModelExecute(modelId, input, output, runSync, stream, handle);
}

/*
 * Unload the model identified by modelId via UnloadModel().
 * Returns ACL_SUCCESS on success; otherwise logs the failure and returns the
 * status reported by UnloadModel().
 */
aclError aclmdlUnload(uint32_t modelId) {
  const Status unloadRet = UnloadModel(modelId);
  if (unloadRet == SUCCESS) {
    return ACL_SUCCESS;
  }
  ACL_LOG_CALL_ERROR("unload model failed, ret[%u], modelId[%u]", unloadRet, modelId);
  return unloadRet;
}

/*
 * Query the work and weight sizes for the model file fileName.
 *
 * All three pointers must be non-NULL, otherwise ACL_ERROR_INVALID_PARAM is
 * returned. The query itself is delegated to GetMemAndWeightSize(), which
 * receives workSize and weightSize as output arguments; a failure there is
 * logged and its status returned.
 */
aclError aclmdlQuerySize(const char *fileName, size_t *workSize, size_t *weightSize) {
  const bool argsPresent = (fileName != NULL) && (workSize != NULL) && (weightSize != NULL);
  if (!argsPresent) {
    ACL_LOG_ERROR("param must not be null.");
    return ACL_ERROR_INVALID_PARAM;
  }

  const Status queryRet = GetMemAndWeightSize(fileName, workSize, weightSize);
  if (queryRet == SUCCESS) {
    return ACL_SUCCESS;
  }
  ACL_LOG_CALL_ERROR("query size failed, ge ret[%u]", queryRet);
  return queryRet;
}

/*
 * Query the partition sizes of the model file fileName.
 *
 * Both pointers must be non-NULL, otherwise ACL_ERROR_INVALID_PARAM is
 * returned. mdlPartitionSize is passed to GetPartitionSize() reinterpreted
 * as a GePartitionSize pointer; a failure there is logged and its status
 * returned.
 */
aclError aclmdlQueryExeOMDesc(const char *fileName, aclmdlExeOMDesc *mdlPartitionSize) {
  const bool argsPresent = (fileName != NULL) && (mdlPartitionSize != NULL);
  if (!argsPresent) {
    ACL_LOG_ERROR("param must not be null.");
    return ACL_ERROR_INVALID_PARAM;
  }

  const Status queryRet = GetPartitionSize(fileName, (GePartitionSize *)mdlPartitionSize);
  if (queryRet == SUCCESS) {
    return ACL_SUCCESS;
  }
  ACL_LOG_CALL_ERROR("query partition size failed, ge ret[%u]", queryRet);
  return queryRet;
}

/*
 * Allocate and initialize an empty dataset.
 *
 * The DataSet is obtained from mmMalloc(); its blobs vector is initialized
 * for DataBlob-sized elements, io_addr is set to NULL and ioa_size to 0.
 * Returns the new dataset as an aclmdlDataset pointer, or NULL when the
 * allocation fails.
 *
 * The parameter list is spelled (void) so that the definition is a real
 * prototype; an empty () leaves the parameters unspecified before C23.
 */
aclmdlDataset *aclmdlCreateDataset(void) {
  DataSet *dataset = (DataSet *)mmMalloc(sizeof(DataSet));
  if (dataset == NULL) {
    return NULL;
  }
  InitVector(&dataset->blobs, sizeof(DataBlob));
  dataset->io_addr = NULL;
  dataset->ioa_size = 0U;
  return (aclmdlDataset *)dataset;
}

/*
 * Destroy a dataset.
 *
 * After the NULL check performed by ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT,
 * the blobs vector is torn down with DeInitVector(), io_addr is released
 * with rtFree() when it is non-NULL (the rtFree() result is ignored), the
 * io fields are reset, and the DataSet itself is released with mmFree().
 * Returns ACL_SUCCESS once those steps have run.
 */
aclError aclmdlDestroyDataset(const aclmdlDataset *dataset) {
  ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(dataset);

  /* The public handle is const; the underlying DataSet is what gets released. */
  DataSet *ds = (DataSet *)dataset;
  DeInitVector(&ds->blobs);

  if (ds->io_addr != NULL) {
    (void)rtFree(ds->io_addr);
  }
  ds->io_addr = NULL;
  ds->ioa_size = 0U;

  mmFree(ds);
  return ACL_SUCCESS;
}

/*
 * Append dataBuffer to dataset.
 *
 * Both arguments must be non-NULL, otherwise ACL_ERROR_INVALID_PARAM is
 * returned. A DataBlob referring to the buffer is pushed onto the dataset's
 * blobs vector with EmplaceBackVector(); if that returns NULL,
 * ACL_ERROR_INTERNAL_ERROR is returned. Returns ACL_SUCCESS otherwise.
 */
aclError aclmdlAddDatasetBuffer(aclmdlDataset *dataset, aclDataBuffer *dataBuffer) {
  const bool argsPresent = (dataset != NULL) && (dataBuffer != NULL);
  if (!argsPresent) {
    ACL_LOG_ERROR("invalid input args.");
    return ACL_ERROR_INVALID_PARAM;
  }

  DataSet *ds = (DataSet *)dataset;
  DataBlob entry;
  entry.dataBuffer = (DataBuffer *)dataBuffer;

  if (EmplaceBackVector(&ds->blobs, &entry) != NULL) {
    return ACL_SUCCESS;
  }
  return ACL_ERROR_INTERNAL_ERROR;
}

/*
 * Return the number of entries in the dataset's blobs vector, as reported by
 * VectorSize(). A NULL dataset is logged, reported through
 * REPORT_INPUT_ERROR, and yields 0.
 */
size_t aclmdlGetDatasetNumBuffers(const aclmdlDataset *dataset) {
  if (dataset != NULL) {
    const DataSet *ds = (const DataSet *)dataset;
    return VectorSize(&ds->blobs);
  }

  ACL_LOG_ERROR("invalid input args.");
  REPORT_INPUT_ERROR(INVALID_NULL_POINTER_MSG, ARRAY("param"), ARRAY("dataset"));
  return 0U;
}

/*
 * Return the data buffer stored at position index in the dataset.
 *
 * The entry is looked up with VectorAt() on the dataset's blobs vector.
 * Returns NULL (after logging) when dataset is NULL or when VectorAt()
 * yields no entry for index.
 */
aclDataBuffer *aclmdlGetDatasetBuffer(const aclmdlDataset *dataset, size_t index) {
  if (dataset == NULL) {
    ACL_LOG_ERROR("invalid input args.");
    return NULL;
  }

  DataSet *ds = (DataSet *)dataset;
  DataBlob *entry = VectorAt(&ds->blobs, index);
  if (entry != NULL) {
    return (aclDataBuffer *)entry->dataBuffer;
  }

  ACL_LOG_ERROR("dataset has no data.");
  return NULL;
}